Skip to content

[Backend] Refactor Transform Pipeline to support different backends#2189

Merged
SiriusNEO merged 9 commits into
tile-ai:mainfrom
SiriusNEO:chaofan/backend_0507
May 27, 2026
Merged

[Backend] Refactor Transform Pipeline to support different backends#2189
SiriusNEO merged 9 commits into
tile-ai:mainfrom
SiriusNEO:chaofan/backend_0507

Conversation

@SiriusNEO

@SiriusNEO SiriusNEO commented May 12, 2026

Copy link
Copy Markdown
Collaborator

Tracking Issue: #2115

This PR majorly refactors the pass pipeline part. In our design, each backend should have its own pipeline since there are too many hardware-specific passes for now. Currently the pipelines in different backends are similar, but they will be gradually different in the future.

Summary by CodeRabbit

  • Refactor

    • Reworked lowering into a modular, registered backend pipeline system and routed compilation through target-specific pipelines.
  • New Features

    • Added dedicated backend pipelines for CUDA, ROCm (HIP), Metal, and CPU.
    • Introduced pre-lowering semantic checks with optional AST printing and pipeline-config helpers including layout-visualization toggles.
  • Tests

    • Updated tests to exercise the new backend pipeline entrypoints.

Review Change Stack

@coderabbitai

coderabbitai Bot commented May 12, 2026

Copy link
Copy Markdown
Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9eff69ad-4ac0-4839-b434-e6ea84be9a5a

📥 Commits

Reviewing files that changed from the base of the PR and between 5a6db2c and ce3150d.

📒 Files selected for processing (6)
  • testing/python/issue/test_tilelang_issue_2123.py
  • testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py
  • testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py
  • testing/python/transform/test_tilelang_transform_lower_shared_barrier.py
  • testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py
  • tilelang/backend/cuda/pipeline.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/backend/cuda/pipeline.py

📝 Walkthrough

Walkthrough

Adds a Pipeline registry and per-backend pipeline bodies (CPU/CUDA/ROCm/Metal), pass-context helpers, a PreLowerSemanticCheck pass, refactors engine and JIT lowering to resolve and invoke target-specific pipelines, and updates CUDA-targeted tests to use the new CUDA prologue lowering.

Changes

Backend Pipeline Architecture and Integration

Layer / File(s) Summary
Pipeline registry and resolution
tilelang/backend/pipeline.py
Pipeline class stores backend name and lowering callable; _PIPELINES registry with register_pipeline, get_pipeline, and resolve_pipeline(target) enable target-kind–driven pipeline selection.
Pass-context configuration predicates
tilelang/backend/pipeline_utils.py
Helpers (allow_vectorize, allow_global_thread_synchronization, should_enable_aggressive_merge, should_force_let_inline, should_enable_layout_visual, should_enable_race_check, should_disable_shared_memory_reuse) query PassContext; get_layout_visual_formats() parses/validates formats; LayoutVisual(mod) conditionally applies layout visualization.
Pre-lowering semantic validation
tilelang/engine/semantic_check.py
PreLowerSemanticCheck(mod) runs optional AST printing and NestedLoopChecker/FragmentLoopChecker semantic validation, gated by configuration flags.
CPU pipeline (c/llvm backends)
tilelang/backend/cpu/__init__.py, tilelang/backend/cpu/pipeline.py
CPUPassPipelineBody applies an ordered TVM/TileLang pass sequence with conditional let-inline, race checks, and vectorization; registered for "c" and "llvm".
Common/WebGPU pipeline registration
tilelang/backend/common.py
On import, registers Pipeline("webgpu", CPUPassPipelineBody) for WebGPU backend kind.
CUDA pipeline with warp-specialization
tilelang/backend/cuda/__init__.py, tilelang/backend/cuda/pipeline.py
Adds allow_warp_specialized(pass_ctx, target) and module_has_tma(mod) helpers; CUDAPassPipelineBodyPrologue prepares/binds target and layout, CUDAPassPipelineBody completes CUDA-specific lowering with conditional TMA fusion and warp-group register annotation; registered as "cuda".
ROCm/HIP pipeline
tilelang/backend/rocm/__init__.py, tilelang/backend/rocm/pipeline.py
ROCMPassPipelineBody applies conditional let-inline and race checks, full TileLang/TVM lowering, and is registered as "hip".
Metal pipeline
tilelang/backend/metal/__init__.py, tilelang/backend/metal/pipeline.py
MetalPassPipelineBody applies target binding, conditional let-inline and race checks, full lowering sequence including layout and synchronization, and registers as "metal".
Backend package auto-registration
tilelang/backend/__init__.py
Package-level imports of cpu, common, cuda, metal, rocm trigger pipeline registration side effects on tilelang.backend import.
Engine lowering refactored to use pipelines
tilelang/engine/lower.py
Removes explicit LowerAndLegalize/OptimizeForTarget phases; runs PreLowerSemanticCheck(mod), resolves pipeline via resolve_pipeline(target), and invokes pipeline.lower(mod, target).
JIT adapter integrated with pipelines
tilelang/jit/adapter/utils.py
get_annotated_mod now uses PreLowerSemanticCheck and resolve_pipeline(target).lower(mod, target) instead of the prior explicit sequence.
CUDA-targeted transform test updates
testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py, testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py, testing/python/transform/test_tilelang_transform_lower_shared_barrier.py, testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py, testing/python/issue/test_tilelang_issue_2123.py
Tests and helpers updated to import and call CUDAPassPipelineBodyPrologue where LowerAndLegalize was previously used, preserving subsequent TileLang transform expectations.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • tile-ai/tilelang#2003: Introduces tcgen05-related passes; this PR updates the same tcgen05 test lowering path to use the CUDA pipeline prologue.
  • tile-ai/tilelang#1288: Adds nested-loop checker used in pre-lowering; this PR integrates a PreLowerSemanticCheck that runs the nested-loop checker.

Suggested reviewers

  • kurisu6912
  • LeiWang1999

Poem

🐇 I hop through pipelines neat and clear,

Each backend finds its lowering gear.
From prologue prep to final song,
Modular hops make passes strong.
Hooray — the rabbit says: ship it, dear!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 25.53% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[Backend] Refactor Transform Pipeline to support different backends' accurately summarizes the main change—refactoring the transform pipeline to be backend-specific.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@SiriusNEO SiriusNEO changed the title [WIP][Backend] Refactor Pipeline to support different backends [WIP][Backend] Refactor Transform Pipeline to support different backends May 12, 2026
@SiriusNEO SiriusNEO force-pushed the chaofan/backend_0507 branch from 3241ead to 997c2e2 Compare May 12, 2026 09:52
@SiriusNEO SiriusNEO force-pushed the chaofan/backend_0507 branch from 997c2e2 to f83a7e1 Compare May 26, 2026 07:10
SiriusNEO added 2 commits May 26, 2026 17:12
Introduce Pipeline abstraction in backend/pipeline.py with per-backend
registration. Each backend (cuda, hip, c, llvm) now registers its own
compilation pass pipeline. engine/lower.py resolves the pipeline from
the target instead of hardcoding phase imports.
@SiriusNEO SiriusNEO force-pushed the chaofan/backend_0507 branch from 05cde9d to a267e6e Compare May 26, 2026 09:16
@SiriusNEO SiriusNEO force-pushed the chaofan/backend_0507 branch from a267e6e to 523136f Compare May 26, 2026 09:31
@SiriusNEO SiriusNEO marked this pull request as ready for review May 26, 2026 09:31
@SiriusNEO SiriusNEO changed the title [WIP][Backend] Refactor Transform Pipeline to support different backends [Backend] Refactor Transform Pipeline to support different backends May 26, 2026

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
tilelang/backend/rocm/pipeline.py (1)

18-82: 🏗️ Heavy lift

Consider extracting a shared lowering spine with backend hook points.

This function duplicates most of tilelang.engine.pass_pipeline.LowerCommon (ordering and pass set), which will be costly to keep in sync as pipelines evolve. A shared helper with backend-specific pre/post hooks would reduce drift risk.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/backend/rocm/pipeline.py` around lines 18 - 82, The lower_amd
function duplicates the bulk of tilelang.engine.pass_pipeline.LowerCommon;
refactor by extracting the shared lowering spine into a single helper (e.g.,
LowerCommonSpine or reuse LowerCommon) that encapsulates the common sequence of
transforms and returns a mod and pass_ctx, then change lower_amd to call that
helper and only perform AMD-specific pre/post hook passes (e.g., any passes
before/after LayoutVisual, the calls that rely on pass_ctx like
VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx)) and
MergeSharedMemoryAllocations(...should_enable_aggressive_merge(pass_ctx=pass_ctx)),
and conditional branches like allow_global_thread_synchronization()); keep
unique AMD-only transforms (e.g., LowerLDGSTG, ThreadSync("shared.dyn"),
MarkCudaSyncCalls(False), LowerDeviceKernelLaunch) in lower_amd and delegate the
rest to the shared helper so ordering stays centralized.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/backend/pipeline.py`:
- Around line 29-36: The register_pipeline function currently overwrites
existing entries in the _PIPELINES mapping when pipeline.name already exists;
change register_pipeline to first check if pipeline.name is already present in
_PIPELINES and, instead of silently replacing it, raise a clear exception (e.g.,
ValueError) or log/error and refuse registration so callers must explicitly
unregister/replace before adding; reference the register_pipeline function, the
_PIPELINES mapping, and pipeline.name when locating where to add the existence
check and error handling.

In `@tilelang/engine/pass_pipeline.py`:
- Around line 68-85: The code treats "all" specially only when formats_str ==
"all", so mixed inputs like "txt,all" end up forwarding the literal "all";
update the parsing logic so after building formats_list from formats_str (the
variable formats_list), if "all" appears expand/replace it with the canonical
list ["txt","png","pdf","svg"] (or return that union) before performing the
invalid_formats check; ensure valid_formats used for validation does not treat
"all" as a valid concrete format (keep "all" only as a special token) and
validate against the canonical formats, referencing formats_str, formats_list,
valid_formats, invalid_formats and the TL_LAYOUT_VISUALIZATION_FORMATS setting.

---

Nitpick comments:
In `@tilelang/backend/rocm/pipeline.py`:
- Around line 18-82: The lower_amd function duplicates the bulk of
tilelang.engine.pass_pipeline.LowerCommon; refactor by extracting the shared
lowering spine into a single helper (e.g., LowerCommonSpine or reuse
LowerCommon) that encapsulates the common sequence of transforms and returns a
mod and pass_ctx, then change lower_amd to call that helper and only perform
AMD-specific pre/post hook passes (e.g., any passes before/after LayoutVisual,
the calls that rely on pass_ctx like
VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx)) and
MergeSharedMemoryAllocations(...should_enable_aggressive_merge(pass_ctx=pass_ctx)),
and conditional branches like allow_global_thread_synchronization()); keep
unique AMD-only transforms (e.g., LowerLDGSTG, ThreadSync("shared.dyn"),
MarkCudaSyncCalls(False), LowerDeviceKernelLaunch) in lower_amd and delegate the
rest to the shared helper so ordering stays centralized.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 425c0493-ca4b-4cec-bf94-f136d7920f16

📥 Commits

Reviewing files that changed from the base of the PR and between 928c942 and 523136f.

📒 Files selected for processing (16)
  • testing/python/transform/_transform_testing_utils.py
  • testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py
  • testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py
  • testing/python/transform/test_tilelang_transform_lower_shared_barrier.py
  • testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py
  • tilelang/backend/__init__.py
  • tilelang/backend/common.py
  • tilelang/backend/cpu/__init__.py
  • tilelang/backend/cuda/__init__.py
  • tilelang/backend/cuda/pipeline.py
  • tilelang/backend/pipeline.py
  • tilelang/backend/rocm/__init__.py
  • tilelang/backend/rocm/pipeline.py
  • tilelang/engine/lower.py
  • tilelang/engine/pass_pipeline.py
  • tilelang/jit/adapter/utils.py

Comment on lines +29 to +36
def register_pipeline(pipeline: Pipeline) -> Pipeline:
"""Register a lowering pipeline for a backend.

The pipeline name should match ``target.kind.name`` (e.g. ``"cuda"``,
``"hip"``, ``"c"``, ``"llvm"``).
"""
_PIPELINES[pipeline.name] = pipeline
return pipeline

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Prevent silent pipeline overwrite on duplicate backend registration.

At Line 35, duplicate backend names overwrite prior registrations without any signal. That can silently switch the lowering path for a target and make behavior depend on import/registration order.

Suggested fix
 def register_pipeline(pipeline: Pipeline) -> Pipeline:
@@
-    _PIPELINES[pipeline.name] = pipeline
+    if pipeline.name in _PIPELINES:
+        raise ValueError(
+            f"Pipeline '{pipeline.name}' is already registered."
+        )
+    _PIPELINES[pipeline.name] = pipeline
     return pipeline
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/backend/pipeline.py` around lines 29 - 36, The register_pipeline
function currently overwrites existing entries in the _PIPELINES mapping when
pipeline.name already exists; change register_pipeline to first check if
pipeline.name is already present in _PIPELINES and, instead of silently
replacing it, raise a clear exception (e.g., ValueError) or log/error and refuse
registration so callers must explicitly unregister/replace before adding;
reference the register_pipeline function, the _PIPELINES mapping, and
pipeline.name when locating where to add the existence check and error handling.

Comment thread tilelang/engine/pass_pipeline.py Outdated

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/backend/cuda/pipeline.py (1)

20-29: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Guard against None target in allow_warp_specialized (CUDA pipeline)

  • In tilelang/backend/cuda/pipeline.py lines 20-29, target: Target | None is forwarded to is_cuda_target(target) and have_tma(target) without checking for None.
  • is_cuda_target (tilelang/jit/adapter/utils.py) directly reads target.kind.name (no None handling), and have_tma (tilelang/contrib/nvcc.py) similarly reads target.kind.name (no None handling), so target=None would raise.
  • Even if current call sites pass a non-None target, the function’s signature implies it can be called without one—add if target is None: return False (and mirror in tilelang/engine/phase.py) or tighten the signature to require Target.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/backend/cuda/pipeline.py` around lines 20 - 29, The function
allow_warp_specialized currently forwards target to is_cuda_target and have_tma
without guarding for None; add an early check "if target is None: return False"
at the top of allow_warp_specialized so neither is_cuda_target nor have_tma are
called with None, and mirror the same guard in the related entry
(tilelang/engine/phase.py) or alternately change the signature to require a
non-None Target; keep the existing pass_ctx handling and the config key
"tl.disable_warp_specialized" unchanged.
tilelang/engine/pass_pipeline.py (1)

24-27: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Fix crash in CommonPassPipelineBody: should_enable_aggressive_merge called with unsupported target=
tilelang/engine/pass_pipeline.py defines should_enable_aggressive_merge(pass_ctx=...) with no target parameter, but CommonPassPipelineBody calls it as should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target), which will raise TypeError and abort lowering. (There’s a target-aware version of this helper in tilelang/engine/phase.py.)

Proposed minimal fix
-    enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
+    enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/engine/pass_pipeline.py` around lines 24 - 27,
CommonPassPipelineBody is calling
should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) but the current
helper signature def should_enable_aggressive_merge(pass_ctx: PassContext | None
= None) lacks a target parameter and raises TypeError; update
should_enable_aggressive_merge to accept an optional target parameter (e.g.,
target: str | None = None) and either forward to the target-aware implementation
in tilelang.engine.phase or use target when selecting the config, ensuring the
function still works when called without target and preserves existing behavior
when pass_ctx is None by obtaining the PassContext via
tilelang.transform.get_pass_context().
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/backend/cuda/pipeline.py`:
- Around line 33-38: The docstring currently references the pass name
"LowerAndLegalize" but the pipeline actually calls LowerTileOp and relies on the
tl.has_tma attribute set by that pass; update the docstring to mention
LowerTileOp instead of LowerAndLegalize so it matches the implementation (the
check that reads tl.has_tma set by LowerTileOp). Ensure the text explicitly
names LowerTileOp and keeps the rest of the explanation intact.

---

Outside diff comments:
In `@tilelang/backend/cuda/pipeline.py`:
- Around line 20-29: The function allow_warp_specialized currently forwards
target to is_cuda_target and have_tma without guarding for None; add an early
check "if target is None: return False" at the top of allow_warp_specialized so
neither is_cuda_target nor have_tma are called with None, and mirror the same
guard in the related entry (tilelang/engine/phase.py) or alternately change the
signature to require a non-None Target; keep the existing pass_ctx handling and
the config key "tl.disable_warp_specialized" unchanged.

In `@tilelang/engine/pass_pipeline.py`:
- Around line 24-27: CommonPassPipelineBody is calling
should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) but the current
helper signature def should_enable_aggressive_merge(pass_ctx: PassContext | None
= None) lacks a target parameter and raises TypeError; update
should_enable_aggressive_merge to accept an optional target parameter (e.g.,
target: str | None = None) and either forward to the target-aware implementation
in tilelang.engine.phase or use target when selecting the config, ensuring the
function still works when called without target and preserves existing behavior
when pass_ctx is None by obtaining the PassContext via
tilelang.transform.get_pass_context().
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9cabcc67-b996-4519-9ac5-3e0652c54d54

📥 Commits

Reviewing files that changed from the base of the PR and between 523136f and d450693.

📒 Files selected for processing (4)
  • tilelang/backend/common.py
  • tilelang/backend/cuda/pipeline.py
  • tilelang/backend/rocm/pipeline.py
  • tilelang/engine/pass_pipeline.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tilelang/backend/common.py
  • tilelang/backend/rocm/pipeline.py

Comment on lines +33 to +38
"""Check if any function in the module was lowered with TMA operations.

This reads the ``tl.has_tma`` attribute set by ``LowerTileOp`` during
``LowerAndLegalize``, which is the source of truth for whether TMA
copies were actually generated.
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update docstring to match actual pass name.

The docstring mentions LowerAndLegalize, but the actual pipeline calls LowerTileOp() at line 87. Update the docstring to reference the correct pass name.

📝 Proposed fix
-    This reads the ``tl.has_tma`` attribute set by ``LowerTileOp`` during
-    ``LowerAndLegalize``, which is the source of truth for whether TMA
-    copies were actually generated.
+    This reads the ``tl.has_tma`` attribute set by ``LowerTileOp``,
+    which is the source of truth for whether TMA copies were actually
+    generated.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/backend/cuda/pipeline.py` around lines 33 - 38, The docstring
currently references the pass name "LowerAndLegalize" but the pipeline actually
calls LowerTileOp and relies on the tl.has_tma attribute set by that pass;
update the docstring to mention LowerTileOp instead of LowerAndLegalize so it
matches the implementation (the check that reads tl.has_tma set by LowerTileOp).
Ensure the text explicitly names LowerTileOp and keeps the rest of the
explanation intact.

@SiriusNEO

Copy link
Copy Markdown
Collaborator Author

@regression-perf

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
tilelang/backend/metal/pipeline.py (1)

23-28: 💤 Low value

Consider passing pass_ctx consistently to predicate functions.

pass_ctx is retrieved on line 21 but not passed to should_force_let_inline() and should_enable_race_check(), while allow_vectorize(), should_enable_aggressive_merge(), and should_disable_shared_memory_reuse() later in the function receive it explicitly. If these functions internally call get_pass_context() anyway, this is fine; otherwise, consider passing pass_ctx for consistency.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/backend/metal/pipeline.py` around lines 23 - 28, pass_ctx is
retrieved earlier but not forwarded to the predicate calls
should_force_let_inline() and should_enable_race_check(); update those calls to
accept and pass the already-obtained pass_ctx for consistency with
allow_vectorize(pass_ctx), should_enable_aggressive_merge(pass_ctx), and
should_disable_shared_memory_reuse(pass_ctx) so all predicate functions use the
same context (i.e., replace should_force_let_inline() with
should_force_let_inline(pass_ctx) and should_enable_race_check() with
should_enable_race_check(pass_ctx)).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/backend/pipeline_utils.py`:
- Around line 55-80: The get_layout_visual_formats function currently allows
"all" to appear inside comma-separated lists (e.g., "txt,all,png"); update the
validation so that after splitting into formats_list you explicitly reject "all"
when formats_list has more than one item by raising a ValueError (keep the
existing expansion behavior when formats_str == "all"); include a clear error
message referencing TL_LAYOUT_VISUALIZATION_FORMATS and that "all" must be used
alone, and keep the existing invalid_formats check for other invalid tokens (use
function name get_layout_visual_formats and variable formats_list to locate
where to add the extra check).

---

Nitpick comments:
In `@tilelang/backend/metal/pipeline.py`:
- Around line 23-28: pass_ctx is retrieved earlier but not forwarded to the
predicate calls should_force_let_inline() and should_enable_race_check(); update
those calls to accept and pass the already-obtained pass_ctx for consistency
with allow_vectorize(pass_ctx), should_enable_aggressive_merge(pass_ctx), and
should_disable_shared_memory_reuse(pass_ctx) so all predicate functions use the
same context (i.e., replace should_force_let_inline() with
should_force_let_inline(pass_ctx) and should_enable_race_check() with
should_enable_race_check(pass_ctx)).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ab2aa29b-da72-4347-9932-ed9dda584bb5

📥 Commits

Reviewing files that changed from the base of the PR and between d450693 and 0bb482e.

📒 Files selected for processing (13)
  • testing/python/transform/_transform_testing_utils.py
  • tilelang/backend/__init__.py
  • tilelang/backend/common.py
  • tilelang/backend/cpu/__init__.py
  • tilelang/backend/cpu/pipeline.py
  • tilelang/backend/cuda/pipeline.py
  • tilelang/backend/metal/__init__.py
  • tilelang/backend/metal/pipeline.py
  • tilelang/backend/pipeline_utils.py
  • tilelang/backend/rocm/pipeline.py
  • tilelang/engine/lower.py
  • tilelang/engine/semantic_check.py
  • tilelang/jit/adapter/utils.py
✅ Files skipped from review due to trivial changes (2)
  • tilelang/backend/metal/init.py
  • tilelang/backend/cpu/init.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • tilelang/backend/init.py
  • tilelang/jit/adapter/utils.py
  • tilelang/engine/lower.py
  • testing/python/transform/_transform_testing_utils.py
  • tilelang/backend/cuda/pipeline.py

Comment on lines +55 to +80
def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "")
if not formats_value:
return ["txt"]

formats_str = formats_value.strip().lower()
valid_formats = ["txt", "png", "pdf", "svg", "all"]

if formats_str == "all":
return ["txt", "png", "pdf", "svg"]

if "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(",")]
else:
formats_list = [formats_str]

invalid_formats = [f for f in formats_list if f not in valid_formats]
if invalid_formats:
raise ValueError(
f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. "
f"Valid formats are: {valid_formats}. "
f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')"
)
return formats_list

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Validate that "all" is not mixed with other formats.

The validation logic allows "all" to appear in a comma-separated list (e.g., "txt,all,png"), but "all" should only be valid as a standalone value. When used alone, it expands to all formats (line 66), but when mixed with other formats in a list, it passes validation and is returned as-is, which the downstream LayoutVisual function likely cannot handle.

🛡️ Proposed fix to reject "all" in comma-separated lists
     if formats_str == "all":
         return ["txt", "png", "pdf", "svg"]
 
     if "," in formats_str:
         formats_list = [f.strip() for f in formats_str.split(",")]
     else:
         formats_list = [formats_str]
 
+    if "all" in formats_list:
+        raise ValueError(
+            "The format 'all' cannot be used in a comma-separated list. "
+            "Use 'all' alone to enable all formats, or specify individual formats (e.g., 'txt,png,pdf')."
+        )
+
     invalid_formats = [f for f in formats_list if f not in valid_formats]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "")
if not formats_value:
return ["txt"]
formats_str = formats_value.strip().lower()
valid_formats = ["txt", "png", "pdf", "svg", "all"]
if formats_str == "all":
return ["txt", "png", "pdf", "svg"]
if "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(",")]
else:
formats_list = [formats_str]
invalid_formats = [f for f in formats_list if f not in valid_formats]
if invalid_formats:
raise ValueError(
f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. "
f"Valid formats are: {valid_formats}. "
f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')"
)
return formats_list
def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "")
if not formats_value:
return ["txt"]
formats_str = formats_value.strip().lower()
valid_formats = ["txt", "png", "pdf", "svg", "all"]
if formats_str == "all":
return ["txt", "png", "pdf", "svg"]
if "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(",")]
else:
formats_list = [formats_str]
if "all" in formats_list:
raise ValueError(
"The format 'all' cannot be used in a comma-separated list. "
"Use 'all' alone to enable all formats, or specify individual formats (e.g., 'txt,png,pdf')."
)
invalid_formats = [f for f in formats_list if f not in valid_formats]
if invalid_formats:
raise ValueError(
f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. "
f"Valid formats are: {valid_formats}. "
f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')"
)
return formats_list
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/backend/pipeline_utils.py` around lines 55 - 80, The
get_layout_visual_formats function currently allows "all" to appear inside
comma-separated lists (e.g., "txt,all,png"); update the validation so that after
splitting into formats_list you explicitly reject "all" when formats_list has
more than one item by raising a ValueError (keep the existing expansion behavior
when formats_str == "all"); include a clear error message referencing
TL_LAYOUT_VISUALIZATION_FORMATS and that "all" must be used alone, and keep the
existing invalid_formats check for other invalid tokens (use function name
get_layout_visual_formats and variable formats_list to locate where to add the
extra check).

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py (1)

5-5: ⚡ Quick win

Keep these tests on a bounded CUDA lowering helper.

Importing and invoking CUDAPassPipelineBody here makes the assertions depend on every later CUDA backend pass, not just the tcgen05_ld/st lowering these tests are trying to validate. Since this PR explicitly allows backend pipelines to diverge, unrelated pass reordering can now break these tests even when the transform under test is still correct. Please route these through the narrower CUDA test helper that stops at the required lowering stage instead of the full backend pipeline.

Also applies to: 121-122, 169-170

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py` at
line 5, The tests import and invoke CUDAPassPipelineBody which runs the full
CUDA backend pipeline and makes assertions fragile; replace the direct
use/import of CUDAPassPipelineBody with the project’s bounded CUDA test helper
that stops at the tcgen05_ld/st lowering stage (i.e., import and call the
narrower CUDA lowering helper instead of CUDAPassPipelineBody), update the
import and call sites accordingly, and make the same replacement for the other
occurrences in this file so the tests only exercise the intended lowering pass.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py`:
- Line 5: The tests import and invoke CUDAPassPipelineBody which runs the full
CUDA backend pipeline and makes assertions fragile; replace the direct
use/import of CUDAPassPipelineBody with the project’s bounded CUDA test helper
that stops at the tcgen05_ld/st lowering stage (i.e., import and call the
narrower CUDA lowering helper instead of CUDAPassPipelineBody), update the
import and call sites accordingly, and make the same replacement for the other
occurrences in this file so the tests only exercise the intended lowering pass.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: df2cb2bd-6f18-4c60-9d50-25f93bfe06d1

📥 Commits

Reviewing files that changed from the base of the PR and between 0bb482e and 5a6db2c.

📒 Files selected for processing (2)
  • testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py
  • tilelang/backend/cuda/pipeline.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/backend/cuda/pipeline.py

@SiriusNEO

Copy link
Copy Markdown
Collaborator Author

@regression-perf

@github-actions

Copy link
Copy Markdown

Performance Regression Test Report

Triggered by: @SiriusNEO
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/26508674573

Results

File Original Latency Current Latency Speedup
example_topk 30.7115 38.749 0.792576
example_mha_inference 0.0598024 0.0612215 0.976821
example_gqa_fwd_bshd 0.0508283 0.051761 0.981981
example_tilelang_gemm_fp8_2xAcc 0.0904924 0.0920889 0.982663
example_gqa_bwd_tma_reduce_varlen 0.0330189 0.0335133 0.985246
example_mhc_pre 0.144142 0.145649 0.98965
example_dequant_gemm_bf16_fp4_hopper 0.395269 0.399246 0.990039
example_linear_attn_fwd 0.028434 0.0287008 0.990704
example_mla_decode 0.312996 0.315874 0.990887
example_dequant_gemm_bf16_mxfp4_hopper 0.359442 0.362693 0.991037
sparse_mla_bwd 0.233308 0.234978 0.992892
example_group_per_split_token_cast_to_fp8 0.00760793 0.00765804 0.993457
example_blocksparse_gemm 0.0137402 0.0138259 0.993802
example_convolution_autotune 0.723049 0.727446 0.993956
example_gqa_sink_bwd_bhsd_sliding_window 0.0180777 0.0181614 0.99539
fp8_lighting_indexer 0.0226158 0.0226998 0.9963
example_mha_sink_fwd_bhsd_sliding_window 0.0126538 0.0126896 0.997183
example_tilelang_block_sparse_attn 0.00723865 0.0072575 0.997401
example_mha_bwd_bshd 0.0288429 0.0289161 0.997467
example_linear_attn_bwd 0.116619 0.116841 0.998104
example_tilelang_gemm_splitk 0.768497 0.76978 0.998333
example_mha_fwd_bshd 0.0190721 0.0191019 0.998442
example_tilelang_sparse_gqa_decode_varlen_indice 0.01188 0.0118942 0.998813
example_mha_sink_fwd_bhsd 0.0126647 0.0126796 0.998819
sparse_mla_fwd 0.082648 0.082744 0.99884
example_warp_specialize_gemm_softpipe_stage2 0.019549 0.0195609 0.999393
example_mha_fwd_varlen 0.0327817 0.0327944 0.999615
example_gqa_decode 0.0413738 0.0413844 0.999744
example_gemm_intrinsics 0.0253567 0.025362 0.999789
example_convolution 0.92221 0.922289 0.999914
example_warp_specialize_gemm_copy_1_gemm_0 0.019555 0.0195501 1.00025
example_tilelang_nsa_decode 0.00550905 0.00550739 1.0003
example_elementwise_add 0.113036 0.113 1.00032
example_tilelang_sparse_gqa_decode_varlen_mask 0.0128829 0.012877 1.00046
example_warp_specialize_gemm_barrierpipe_stage2 0.0294206 0.0294057 1.00051
example_tilelang_nsa_fwd 0.00527823 0.00527555 1.00051
example_per_token_cast_to_fp8 0.00652024 0.00651558 1.00072
example_dequant_gemv_fp16xint4 0.0269897 0.0269642 1.00094
example_warp_specialize_gemm_copy_0_gemm_1 0.0269532 0.0269247 1.00106
example_dynamic 0.49918 0.498382 1.0016
example_gqa_bwd 0.0329455 0.032875 1.00214
sparse_mla_fwd_pipelined 0.0592501 0.0591167 1.00226
example_mhc_post 0.106467 0.106216 1.00236
example_gemm_autotune 0.0162498 0.0162105 1.00242
example_mha_sink_bwd_bhsd_sliding_window 0.0383302 0.0382343 1.00251
topk_selector 0.0414186 0.0413046 1.00276
example_fusedmoe_tilelang 0.095861 0.0954196 1.00463
example_mha_fwd_bhsd 0.00913095 0.00908838 1.00468
block_sparse_attn_tilelang 0.00672394 0.00669155 1.00484
example_gemm 0.0171467 0.0170429 1.00609
example_gemv 0.202618 0.20125 1.0068
example_tilelang_gemm_splitk_vectorize_atomicadd 0.790441 0.784361 1.00775
example_mha_bwd_bhsd 0.0297886 0.029549 1.00811
example_mha_sink_bwd_bhsd 0.0526467 0.052202 1.00852
example_tilelang_gemm_fp8 0.240038 0.237939 1.00882
example_gqa_sink_bwd_bhsd 0.0302829 0.0300085 1.00914
example_vertical_slash_sparse_attn 0.167518 0.165128 1.01448
example_dequant_gemm_fp4_hopper 0.724106 0.00896101 80.8064
example_dequant_gemm_w4a8 3.83191 0.00406223 943.301

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@SiriusNEO

Copy link
Copy Markdown
Collaborator Author

@regression-perf

@github-actions

Copy link
Copy Markdown

Performance Regression Test Report

Triggered by: @SiriusNEO
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/26514293621

Results

File Original Latency Current Latency Speedup
example_dequant_gemm_bf16_mxfp4_hopper 0.356834 0.367039 0.972197
example_mha_sink_bwd_bhsd 0.0517051 0.0531508 0.9728
example_mha_fwd_bhsd 0.00910245 0.00933941 0.974628
example_linear_attn_bwd 0.116455 0.117748 0.989023
example_mhc_pre 0.145161 0.146565 0.990422
example_mha_bwd_bshd 0.0293442 0.0296261 0.990484
example_warp_specialize_gemm_copy_1_gemm_0 0.0194253 0.0195879 0.991701
example_gqa_sink_bwd_bhsd 0.0300057 0.0302338 0.992456
example_gemv 0.201252 0.202564 0.993525
example_dynamic 0.4952 0.498212 0.993954
example_warp_specialize_gemm_copy_0_gemm_1 0.0268782 0.0270337 0.994247
sparse_mla_bwd 0.227897 0.228851 0.995831
example_gqa_decode 0.0414309 0.0416009 0.995915
example_dequant_gemm_bf16_fp4_hopper 0.397407 0.398934 0.996172
example_linear_attn_fwd 0.0285454 0.0286551 0.996174
example_convolution 0.919111 0.922135 0.996721
example_gqa_sink_bwd_bhsd_sliding_window 0.0181029 0.0181624 0.996725
example_mha_fwd_varlen 0.0328866 0.0329932 0.996769
example_convolution_autotune 0.732648 0.734851 0.997003
example_mha_sink_fwd_bhsd_sliding_window 0.0126863 0.0127128 0.997917
example_mha_sink_bwd_bhsd_sliding_window 0.0386221 0.0386968 0.99807
example_gqa_fwd_bshd 0.0516175 0.0517155 0.998105
example_group_per_split_token_cast_to_fp8 0.00760232 0.00761357 0.998523
example_tilelang_sparse_gqa_decode_varlen_mask 0.0128482 0.0128663 0.998599
example_mhc_post 0.106401 0.106545 0.998652
example_gemm 0.0170591 0.0170812 0.998707
example_per_token_cast_to_fp8 0.00651274 0.00652106 0.998725
example_fusedmoe_tilelang 0.0956688 0.0957855 0.998782
example_elementwise_add 0.112977 0.113061 0.999253
example_gemm_intrinsics 0.0253552 0.0253704 0.999401
block_sparse_attn_tilelang 0.00672286 0.00672642 0.99947
example_gqa_bwd 0.0329108 0.0329257 0.999546
example_gqa_bwd_tma_reduce_varlen 0.0338355 0.0338472 0.999653
example_mha_inference 0.0601991 0.0602064 0.999879
example_mla_decode 0.31588 0.315872 1.00002
topk_selector 0.0413754 0.04135 1.00062
example_tilelang_sparse_gqa_decode_varlen_indice 0.011865 0.0118553 1.00082
example_tilelang_block_sparse_attn 0.00724566 0.00723863 1.00097
example_tilelang_gemm_fp8 0.239562 0.239185 1.00158
example_dequant_gemv_fp16xint4 0.0269897 0.0269447 1.00167
sparse_mla_fwd 0.0828156 0.0826686 1.00178
example_tilelang_gemm_splitk 0.768909 0.767434 1.00192
example_blocksparse_gemm 0.0137478 0.0137173 1.00222
example_tilelang_nsa_decode 0.00552103 0.00550567 1.00279
example_mha_sink_fwd_bhsd 0.0127978 0.0127597 1.00299
example_warp_specialize_gemm_softpipe_stage2 0.0194069 0.0193439 1.00326
example_topk 30.697 30.5652 1.00431
example_gemm_autotune 0.0162389 0.0161641 1.00463
sparse_mla_fwd_pipelined 0.0596489 0.0593343 1.0053
example_tilelang_nsa_fwd 0.00529706 0.00526705 1.0057
example_vertical_slash_sparse_attn 0.167316 0.166341 1.00586
fp8_lighting_indexer 0.0227778 0.0226205 1.00695
example_mha_fwd_bshd 0.0190954 0.0189421 1.00809
example_mha_bwd_bhsd 0.0299026 0.0294632 1.01491
example_warp_specialize_gemm_barrierpipe_stage2 0.0296913 0.0291868 1.01728
example_tilelang_gemm_splitk_vectorize_atomicadd 0.789447 0.775431 1.01808
example_tilelang_gemm_fp8_2xAcc 0.0922395 0.0894955 1.03066
example_dequant_gemm_fp4_hopper 0.723855 0.00895699 80.8145
example_dequant_gemm_w4a8 3.82084 0.00404925 943.591

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@SiriusNEO SiriusNEO merged commit 0a9b651 into tile-ai:main May 27, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant